# -*- coding: utf-8 -*-
# @Time    : 2023/3/16 19:48
# @Author  : xiehou
# @File    : inference.py
# @Software: PyCharm
from utils.tokenize import get_tokenizer
from models.casrel import Casrel
import config
import argparse
import torch
import numpy as np
import json

def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is True because any
    non-empty string is truthy, so ``--debug False`` used to *enable* debug
    mode. This converter keeps the same CLI syntax (``--debug True/False``)
    but interprets the value correctly.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', '1', 'yes', 'y'):
        return True
    if value.lower() in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='Casrel', help='name of the pretrained_model')
parser.add_argument('--lr', type=float, default=1e-5)
# BUG FIX: was ``type=bool``, which parsed every non-empty string (including
# "False") as True; _str2bool accepts the same arguments but parses them right.
parser.add_argument('--multi_gpu', type=_str2bool, default=False)
parser.add_argument('--dataset', type=str, default='CMED')
parser.add_argument('--batch_size', type=int, default=6)
parser.add_argument('--max_epoch', type=int, default=200)
parser.add_argument('--test_epoch', type=int, default=1)
parser.add_argument('--train_prefix', type=str, default='train_triples')
parser.add_argument('--dev_prefix', type=str, default='dev_triples')
parser.add_argument('--test_prefix', type=str, default='test_triples')
parser.add_argument('--max_len', type=int, default=150)
parser.add_argument('--rel_num', type=int, default=44)
parser.add_argument('--period', type=int, default=50)
parser.add_argument('--debug', type=_str2bool, default=False)
args = parser.parse_args()

con = config.Config(args)

# Probability thresholds for deciding subject head/tail token positions.
h_bar = 0.5
t_bar = 0.5

texts = [
    "登革热 @ 大约 90% 的 登革出血热 ( dengue haemorrhagic fever , DHF ) 病例 为 5 岁 以下 儿童 。 登革热 @ 典型 的 登革热 更 常 见于 成人 ， 而 非 儿童 。",
    "子宫内膜癌 @ 阴道 短 距离 放疗 和 EBRT 在 控制 阴道 疾病 方面 一样 有效 ， 但 有 更 少 的 胃肠道 副作用 （ GOG-99 、 PORTEC-1 、 PORTEC-2 ） 。"]

# init tokenizer (vocabulary shipped with the chinese-bert-wwm pretrained model)
tokenizer = get_tokenizer('./pretrained_model/chinese_bert_wwm/vocab.txt')
# rel2id.json presumably stores [id2rel, rel2id]; index 0 is the id -> relation
# name map — TODO confirm against the data-preprocessing script.
# BUG FIX: the previous bare ``open`` leaked the file handle and used the
# platform default encoding, which breaks on non-UTF-8 locales for the Chinese
# relation labels.
with open('./data/CMED/rel2id.json', encoding='utf-8') as f:
    id2rel = json.load(f)[0]

# init model
# NOTE(review): './c' looks like a truncated checkpoint path — confirm the real
# weights file before running.
model_path = './c'
model = Casrel(config=con)
# map_location='cpu' lets the checkpoint load regardless of which device it
# was saved on; .cuda() then moves the weights to the GPU.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.cuda()
model.eval()

# Per-text CasRel decoding: predict subject spans first, then for every
# subject predict (relation, object) pairs conditioned on it.
for text in texts:
    # ---- data preparation ----
    tokens = tokenizer.tokenize(text)
    text_len = len(tokens)
    token_ids, segment_ids = tokenizer.encode(first=text)
    masks = segment_ids
    # Keep the id sequence aligned with the token sequence if encode()
    # returned more ids than tokenize() returned tokens.
    if len(token_ids) > text_len:
        token_ids = token_ids[:text_len]
        masks = masks[:text_len]
    token_ids = np.array(token_ids)
    # segment ids are all zero for a single sentence; +1 turns them into an
    # all-ones attention mask.
    masks = np.array(masks) + 1
    max_text_len = text_len + 2
    batch_token_ids = torch.LongTensor(1, max_text_len).zero_()
    batch_masks = torch.LongTensor(1, max_text_len).zero_()
    batch_token_ids[0, :text_len].copy_(torch.from_numpy(token_ids))
    batch_masks[0, :text_len].copy_(torch.from_numpy(masks))
    # BUG FIX: the model was moved to the GPU (model.cuda() above) but the
    # inputs were left on the CPU, which raises a device-mismatch RuntimeError
    # inside the encoder. Put the batch on the same device as the weights.
    device = next(model.parameters()).device
    batch_token_ids = batch_token_ids.to(device)
    batch_masks = batch_masks.to(device)

    # ---- subject prediction ----
    # Inference only: torch.no_grad() skips autograd bookkeeping, saving
    # memory and time.
    with torch.no_grad():
        encoded_text = model.get_encoded_text(batch_token_ids, batch_masks)
        pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)
    # Token positions whose head/tail probability exceeds the threshold.
    sub_heads = np.where(pred_sub_heads.cpu()[0] > h_bar)[0]
    sub_tails = np.where(pred_sub_tails.cpu()[0] > t_bar)[0]
    subjects = []
    for sub_head in sub_heads:
        # Pair each head with the nearest tail at or after it.
        sub_tail = sub_tails[sub_tails >= sub_head]
        if len(sub_tail) > 0:
            sub_tail = sub_tail[0]
            # NOTE(review): the slice is tail-exclusive (tokens[head:tail]);
            # this mirrors the original CasRel decoding code but looks
            # off-by-one — confirm against the span convention used at
            # training time.
            subject = tokens[sub_head: sub_tail]
            subjects.append((subject, sub_head, sub_tail))

    # ---- relation/object prediction, one pass over all subjects ----
    if subjects:
        triple_list = []
        # [subject_num, seq_len, bert_dim]: one encoder copy per subject.
        repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1)
        # [subject_num, 1, seq_len] one-hot maps selecting each subject's
        # head/tail position.
        sub_head_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
        sub_tail_mapping = torch.Tensor(len(subjects), 1, encoded_text.size(1)).zero_()
        for subject_idx, subject in enumerate(subjects):
            sub_head_mapping[subject_idx][0][subject[1]] = 1
            sub_tail_mapping[subject_idx][0][subject[2]] = 1
        # .to(tensor) matches both the device and dtype of the encoded text.
        sub_tail_mapping = sub_tail_mapping.to(repeated_encoded_text)
        sub_head_mapping = sub_head_mapping.to(repeated_encoded_text)
        with torch.no_grad():
            pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(
                sub_head_mapping, sub_tail_mapping, repeated_encoded_text)
        for subject_idx, subject in enumerate(subjects):
            sub = subject[0]
            # NOTE(review): lstrip("##") strips leading '#' *characters*, not
            # the literal "##" prefix; fine for WordPiece pieces but verify no
            # token legitimately begins with '#'.
            sub = ''.join([i.lstrip("##") for i in sub])
            sub = ' '.join(sub.split('[unused1]'))
            # np.where on the [seq_len, rel_num] score matrix yields paired
            # (position, relation) index arrays.
            obj_heads, obj_tails = np.where(pred_obj_heads.cpu()[subject_idx] > h_bar), np.where(
                pred_obj_tails.cpu()[subject_idx] > t_bar)
            for obj_head, rel_head in zip(*obj_heads):
                for obj_tail, rel_tail in zip(*obj_tails):
                    # Valid object span: head not after tail, same relation id
                    # on both ends; take the first matching tail only.
                    if obj_head <= obj_tail and rel_head == rel_tail:
                        rel = id2rel[str(int(rel_head))]
                        obj = tokens[obj_head: obj_tail]
                        obj = ''.join([i.lstrip("##") for i in obj])
                        obj = ' '.join(obj.split('[unused1]'))
                        triple_list.append((sub, rel, obj))
                        break
        # Deduplicate (was a redundant set -> list -> set round trip).
        pred_list = list(set(triple_list))
    else:
        pred_list = []

    pred_triples = set(pred_list)
    print(text)
    for triple in pred_triples:
        print('subject:{},relation:{},object:{}'.format(triple[0], triple[1], triple[2]))
    print('-*' * 50)
